"""

    Classification-based ranking datasets.

"""
import json
import os
import random

from datasets import load_dataset
from torch.utils.data import Dataset

from .utils import webgpt_return_format


class WebGPTDataset(Dataset):
    """Ranking dataset built from the ``openai/webgpt_comparisons`` train split.

    Each item is ``(question, pos, neg)`` read from the formatted comparison
    row. When ``additional_dataset`` is given, each item also carries
    ``gen_neg``: one text drawn at random from the generated negatives stored
    for that comparison's index.
    """

    def __init__(self, mode="train", index_cache="dataset/webgpt_train_idx.pt", additional_dataset=None) -> None:
        """Load and format every comparison, optionally attaching generated negatives.

        Args:
            mode: ``"train"`` or ``"val"``; intended for validation purposes and
                unrelated to the original dataset split. Not currently read by
                this class.
            index_cache: path of an index cache file. Not currently read by
                this class.
            additional_dataset: optional path to a JSON-lines file of generated
                candidates. Each line is an object with:
                ``idx`` -- must match the position of the comparison in the
                enumeration order of the train split;
                ``question`` -- for validation purposes;
                and the list of K generated candidate texts, which this loader
                reads from the ``negatives`` key.

        Raises:
            ValueError: if some comparison has no generated negatives and none
                can be borrowed from another index (see
                ``_fill_missing_negatives``).
        """
        super().__init__()
        # Created unconditionally; presumably meant for ``index_cache`` -- TODO confirm.
        os.makedirs("dataset", exist_ok=True)
        dataset = load_dataset("openai/webgpt_comparisons")
        self.dataset = []
        # Comparison index of each entry in ``self.dataset``; it is the key used
        # to look up generated negatives loaded from ``additional_dataset``.
        self.dataset_index = []
        for idx, row in enumerate(dataset["train"]):
            self.dataset.append(webgpt_return_format(row))
            # Previously never filled, which made every ``additional_dataset``
            # row be discarded and ``__getitem__`` raise IndexError.
            self.dataset_index.append(idx)

        # The answers in this dataset were generated by GPT-3, so we also need
        # samples generated by the starting model: the ranking model must learn
        # to rank GPT-3's answers above the starting model's own generations.
        self.sample_additional = False
        if additional_dataset is not None:
            self.sample_additional = True
            self.additional = self._load_additional(additional_dataset, set(self.dataset_index))
            self._fill_missing_negatives(self.additional, self.dataset_index)

    @staticmethod
    def _load_additional(path, valid_indices):
        """Read generated negatives from a JSON-lines file, keyed by ``idx``.

        Rows whose ``idx`` is not in ``valid_indices`` are skipped, as are
        blank lines. If an ``idx`` occurs on several lines, the last one wins.

        Returns:
            dict mapping ``idx`` to the row's ``negatives`` value.
        """
        additional = {}
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank / trailing empty lines
                row = json.loads(line)
                if row["idx"] in valid_indices:
                    additional[row["idx"]] = row["negatives"]
        return additional

    @staticmethod
    def _fill_missing_negatives(additional, dataset_index, offset=900):
        """Ensure every index in ``dataset_index`` has an entry in ``additional``.

        An index with no generated negatives of its own reuses the list of the
        nearest available index at or below ``match_idx - offset``. The reason
        for the 900 offset is not documented; presumably it is meant to borrow
        from a different question -- TODO confirm. ``additional`` is updated
        in place.

        Raises:
            ValueError: if a missing index has no available index at or below
                ``match_idx - offset`` to borrow from (previously an infinite
                loop).
        """
        if len(additional) == len(dataset_index):
            return
        if not additional:
            raise ValueError("additional_dataset provided no generated negatives for any comparison index")
        # Lower bound for the backward search. Entries added below are always
        # above an existing key, so the bound stays valid inside the loop.
        lowest = min(additional)
        for match_idx in dataset_index:
            if match_idx in additional:
                continue
            idx = match_idx - offset
            while idx >= lowest and idx not in additional:
                idx -= 1
            if idx < lowest:
                raise ValueError(
                    f"no generated negatives at or below index {match_idx - offset} "
                    f"to borrow for comparison index {match_idx}"
                )
            additional[match_idx] = additional[idx]

    def __len__(self):
        """Return the number of comparison rows."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(question, pos, neg)``, plus a random ``gen_neg`` when sampling additional negatives."""
        row = self.dataset[index]
        if not self.sample_additional:
            return row["question"], row["pos"], row["neg"]

        gen_neg = random.choice(self.additional[self.dataset_index[index]])
        return row["question"], row["pos"], row["neg"], gen_neg
